{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNet-50 Inference with FINN on Alveo\n",
    "\n",
    "This notebook demonstrates the functionality of a FINN-based, full dataflow ResNet-50 implemented in Alveo U250. The characteristics of the network are the following:\n",
    " - residual blocks at 1-bit weights, 2/4-bit activations\n",
    " - first convolution and last (fully connected) layer use 8-bit weights\n",
    " - most parameters in BRAM/LUTRAM (no URAM used), FC parameters streamed from DDR\n",
    " - single DDR controller (DDR0) utilized for input, output and FC weight streaming\n",
    "\n",
    "We validate the network against ImageNet. We use the PYNQ APIs for retrieving and recording power information which is then displayed in real-time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up Accelerator with PYNQ\n",
    "We load the Alveo accelerator and print its memory-mapped registers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pynq\n",
    "\n",
    "ol=pynq.Overlay(\"resnet50.xclbin\")\n",
    "accelerator=ol.resnet50_1\n",
    "print(accelerator.register_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we create a data buffer in the Alveo DDR memory to hold the weights of the Fully Connected Layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "#allocate a buffer for FC weights, targeting the Alveo DDR Bank 0\n",
    "fcbuf = pynq.allocate((1000,2048), dtype=np.int8, target=ol.bank0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the weight from a CSV file and push them to the accelerator buffer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#load Weights from file into the PYNQ buffer\n",
    "fcweights = np.genfromtxt(\"fcweights.csv\", delimiter=',', dtype=np.int8)\n",
    "#csv reader erroneously adds one extra element to the end, so remove, then reshape\n",
    "fcweights = fcweights[:-1].reshape(1000,2048)\n",
    "fcbuf[:] = fcweights\n",
    "\n",
    "#Move the data to the Alveo DDR\n",
    "fcbuf.sync_to_device()"
   ]
  },
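  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trimming step above can be reproduced on a tiny example. This is a minimal sketch, assuming the extra element comes from a trailing comma at the end of the file (the notebook does not state the cause). `demo_csv`, `raw` and `demo_weights` are illustrative names and are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import numpy as np\n",
    "\n",
    "#hypothetical stand-in for fcweights.csv: four values followed by a trailing comma\n",
    "demo_csv = io.StringIO('1,2,3,4,')\n",
    "raw = np.genfromtxt(demo_csv, delimiter=',', dtype=np.int8)\n",
    "#the empty field after the last comma is parsed as one extra filler element\n",
    "print(raw.shape)\n",
    "#drop the filler element, then restore the intended 2-D layout\n",
    "demo_weights = raw[:-1].reshape(2, 2)\n",
    "print(demo_weights)"
   ]
  },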
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single Image Inference\n",
    "In this example we perform inference on each of the images in a `pictures` folder and display the top predicted class overlaid onto the image. The code assumes the existence of this `pictures` folder, where you should put the images you want to classificate. There is no restriction on the images that you can use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import wget\n",
    "import os\n",
    "import glob\n",
    "from itertools import chain\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image_list = list(chain.from_iterable([glob.glob('pictures/*.%s' % ext) for ext in [\"jpg\",\"gif\",\"png\",\"tga\"]]))\n",
    "\n",
    "#get imagenet classes from file\n",
    "import pickle\n",
    "classes = pickle.load(open(\"labels.pkl\",'rb'))\n",
    "\n",
    "def infer_once(filename):\n",
    "    inbuf = pynq.allocate((224,224,3), dtype=np.int8, target=ol.bank0)\n",
    "    outbuf = pynq.allocate((5,), dtype=np.uint32, target=ol.bank0)\n",
    "\n",
    "    #preprocess image\n",
    "    img = cv2.resize(cv2.imread(filename), (224,224))\n",
    "\n",
    "    #transfer to accelerator\n",
    "    inbuf[:] = img\n",
    "    inbuf.sync_to_device()\n",
    "    \n",
    "    #do inference\n",
    "    accelerator.call(inbuf, outbuf, fcbuf, 1)\n",
    "\n",
    "    #get results\n",
    "    outbuf.sync_from_device()\n",
    "    results = np.copy(outbuf)\n",
    "    return results\n",
    "\n",
    "inf_results = []\n",
    "for img in image_list:\n",
    "    inf_output = infer_once(img)\n",
    "    inf_result = [classes[i] for i in inf_output]\n",
    "    inf_results.append(inf_result)\n",
    "\n",
    "plt.figure(figsize=(20,10))\n",
    "columns = 3\n",
    "for i, image in enumerate(image_list):\n",
    "    plt.subplot(len(image_list) / columns + 1, columns, i + 1)\n",
    "    top_class = inf_results[i][0].split(',', 1)[0]\n",
    "    display_image = cv2.cvtColor(cv2.resize(cv2.imread(image),(224,224)), cv2.COLOR_BGR2RGB)\n",
    "    plt.imshow(cv2.putText(display_image, top_class, (10,20), cv2.FONT_HERSHEY_TRIPLEX, 0.7, (255,255,255)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Accelerator Board Power with PYNQ\n",
    "We first set up data acquisition using PYNQ's PMBus API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly\n",
    "import plotly.graph_objs as go\n",
    "import pandas as pd\n",
    "from pynq import pmbus\n",
    "import time\n",
    "\n",
    "rails = pmbus.get_xrt_sysfs_rails(pynq.pl_server.Device.active_device)\n",
    "\n",
    "#We create a recorder monitoring the three rails that have power measurement on Alveo. \n",
    "#Total board power is obtained by summing together the PCI Express and Auxilliary 12V rails. \n",
    "#While some current is also drawn over the PCIe 5V rail this is negligible compared to the 12V rails and isn't recorded. \n",
    "#We also measure the VCC_INT power which is the primary supply to the FPGA.\n",
    "\n",
    "recorder = pmbus.DataRecorder(rails[\"12v_aux\"].power,\n",
    "                              rails[\"12v_pex\"].power,\n",
    "                              rails[\"vccint\"].power)\n",
    "\n",
    "f = recorder.frame\n",
    "\n",
    "powers = pd.DataFrame(index=f.index)\n",
    "powers['board_power'] = f['12v_aux_power'] + f['12v_pex_power']\n",
    "powers['fpga_power'] = f['vccint_power']\n",
    "\n",
    "#Now we need to specify the layout for the graph. In this case it will be a simple Line/Scatter plot, \n",
    "#autoranging on both axes with the Y axis having 0 at the bottom.\n",
    "layout = {\n",
    "    'xaxis': {\n",
    "        'title': 'Time (s)'\n",
    "    },\n",
    "    'yaxis': {\n",
    "        'title': 'Power (W)',\n",
    "        'rangemode': 'tozero',\n",
    "        'autorange': True\n",
    "    }\n",
    "}\n",
    "\n",
    "#Plotly expects data in a specific format, namely an array of plotting objects. \n",
    "#This helper function will update the data in a plot based. \n",
    "#Th e `DataRecorder` stores the recording in a Pandas dataframe object with a time-based index. \n",
    "#This makes it easy to pull out the results for a certain time range and compute a moving average. \n",
    "#In this case we are going to give a 5-second moving average of the results as well as the raw input.\n",
    "def update_data(frame, start, end, plot):\n",
    "    ranged = frame[start:end]\n",
    "    average_ranged = frame[start-pd.tseries.offsets.Second(5):end]\n",
    "    rolling = (average_ranged['12v_aux_power'] + average_ranged['12v_pex_power']).rolling(\n",
    "        pd.tseries.offsets.Second(5)\n",
    "    ).mean()[ranged.index]\n",
    "    powers = pd.DataFrame(index=ranged.index)\n",
    "    powers['board_power'] = ranged['12v_aux_power'] + ranged['12v_pex_power']\n",
    "    powers['rolling'] = rolling\n",
    "    data = [\n",
    "        go.Scatter(x=powers.index, y=powers['board_power'], name=\"Board Power\"),\n",
    "        go.Scatter(x=powers.index, y=powers['rolling'], name=\"5 Second Avg\")\n",
    "    ]\n",
    "    plot.update(data=data)\n",
    "    \n",
    "#Next we create an show the plot object, initially there will be no data to display but this plot will be updated after we start the recording. \n",
    "#Once the plot is running it is possible to right click on it to pop out the graph into a separate window.\n",
    "plot = go.FigureWidget(layout=layout)\n",
    "plot"
   ]
  },
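  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The windowing in `update_data` can be tried without a board. The sketch below uses a synthetic frame in place of `recorder.frame` (the values and the names `demo`, `board` and `rolling` are made up for illustration; the column names mirror those used above) and applies the same 5-second, time-based rolling mean. In `update_data` the slice starts 5 seconds before `start`, so that the first displayed points are already averaged over a full window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "#synthetic stand-in for recorder.frame: 10 s of samples at 0.1 s intervals\n",
    "idx = pd.date_range('2020-01-01', periods=100, freq='100ms')\n",
    "demo = pd.DataFrame({'12v_aux_power': np.full(100, 20.0),\n",
    "                     '12v_pex_power': np.linspace(30.0, 40.0, 100)}, index=idx)\n",
    "\n",
    "board = demo['12v_aux_power'] + demo['12v_pex_power']\n",
    "#time-based window: each point averages the samples from the preceding 5 seconds\n",
    "rolling = board.rolling('5s').mean()\n",
    "print(pd.DataFrame({'board_power': board, 'rolling': rolling}).tail())"
   ]
  },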
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we create a dynamically-updating power graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "recorder.record(0.1)\n",
    "\n",
    "#In order to continue updating the graph we need a thread running in the background. \n",
    "#The following thread will call our update function twice a second to display the most recently collected minute of data.\n",
    "do_update = True\n",
    "\n",
    "def thread_func():\n",
    "    while do_update:\n",
    "        now = pd.Timestamp.fromtimestamp(time.time())\n",
    "        past = now - pd.tseries.offsets.Second(60)\n",
    "        update_data(recorder.frame, past, now, plot)\n",
    "        time.sleep(0.5)\n",
    "\n",
    "from threading import Thread\n",
    "t = Thread(target=thread_func)\n",
    "t.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To manually stop the power graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "do_update = False\n",
    "recorder.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthetic Throughput Test\n",
    "We execute inference of a configurable-size batch of images, without data movement. We measure the latency and throughput. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "from IPython.display import clear_output\n",
    "\n",
    "bs = widgets.IntSlider(\n",
    "    value=128,\n",
    "    min=1,\n",
    "    max=1000,\n",
    "    step=1,\n",
    "    description='Batch Size:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='d'\n",
    ")\n",
    "fps = widgets.IntProgress(min=0, max=2500, description='FPS: ')\n",
    "latency = widgets.FloatProgress(min=0, max=0.1, description='Latency (ms): ')\n",
    "\n",
    "button = widgets.Button(description='Stop')\n",
    "stop_running = False\n",
    "\n",
    "def on_button_clicked(_):\n",
    "    global stop_running\n",
    "    stop_running = True\n",
    "            \n",
    "# linking button and function together using a button's method\n",
    "button.on_click(on_button_clicked)\n",
    "\n",
    "out_fps = widgets.Text()\n",
    "out_latency = widgets.Text()\n",
    "\n",
    "ui_top = widgets.HBox([button, bs])\n",
    "ui_bottom = widgets.HBox([fps, out_fps, latency, out_latency])\n",
    "ui = widgets.VBox([ui_top, ui_bottom])\n",
    "display(ui)\n",
    "\n",
    "import time\n",
    "import threading\n",
    "\n",
    "def benchmark_synthetic():\n",
    "    import pynq\n",
    "    ibuf = pynq.allocate((1000,3,224,224), dtype=np.int8, target=ol.bank0)\n",
    "    obuf = pynq.allocate((1000,5), dtype=np.uint32, target=ol.bank0)\n",
    "\n",
    "    while True:\n",
    "        if stop_running:\n",
    "            print(\"Stopping\")\n",
    "            return\n",
    "        duration = time.monotonic()\n",
    "        accelerator.call(ibuf, obuf, fcbuf, bs.value)\n",
    "        duration = time.monotonic() - duration\n",
    "        fps.value = int(bs.value/duration)\n",
    "        latency.value = duration\n",
    "        out_fps.value = str(fps.value)\n",
    "        out_latency.value = '%.2f' % (duration * 1000)\n",
    "        \n",
    "\n",
    "t = threading.Thread(target=benchmark_synthetic)\n",
    "t.start()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
